# hypersense/strategy/greedy_no_fallback.py
from typing import List, Dict, Any
from .greedy_important_first import GreedyImportantFirstStrategy

class GreedyNoFallbackStrategy(GreedyImportantFirstStrategy):
    """
    Same as GIF but disable per-round full-space fallback.

    Each round ranks parameters by importance, chunks them into groups of
    ``top_k`` and optimizes each group in turn while every other parameter is
    held at the current best configuration. Unlike the parent strategy, no
    full-search-space fallback is performed within a round.
    """

    def run(self) -> List[Dict[str, Any]]:
        """
        Run grouped greedy optimization until ``max_total_trials`` is reached.

        The run stops early if a whole round produces no new trials, for
        example because every group's optimizer failed or returned nothing.
        Without that check the loop could never make progress.

        Returns:
            ``self.history``: the pre-existing trials plus every trial added
            here. New trials are dicts with ``config``, ``score``,
            ``elapsed_time``, ``round`` and ``group`` keys.

        Side effects:
            Extends ``self.history`` and ``self.logs``. Sets
            ``self.current_best_config`` and ``self.current_best_score``.

        Raises:
            ValueError: if ``self.history`` is empty. At least one seed trial
                is required to define the starting configuration.
        """
        if not self.history:
            raise ValueError(
                "GreedyNoFallbackStrategy.run requires at least one seed trial in history"
            )

        total_trials_used = len(self.history)

        best_trial = max(self.history, key=lambda t: t["score"])
        current_best_config = dict(best_trial["config"])
        current_best_score = best_trial["score"]

        round_idx = 0
        while total_trials_used < self.max_total_trials:
            configs = [t["config"] for t in self.history]
            scores = [t["score"] for t in self.history]

            if len(self.history) < self.min_trials_for_importance:
                # Too few trials to estimate importance: weight all params equally.
                importance = {k: 1.0 for k in self.search_space}
            else:
                importance = self.importance_evaluator(configs, scores)

            # Most important parameters first, chunked into groups of top_k.
            sorted_param_names = [
                name for name, _ in sorted(importance.items(), key=lambda x: -x[1])
            ]
            groups = [
                sorted_param_names[i:i + self.top_k]
                for i in range(0, len(sorted_param_names), self.top_k)
            ]

            # Split this round's step_trials across groups in proportion to their
            # summed importance, with at least one trial per group. `or 1.0` also
            # covers the all-zero-importance case, which previously divided by zero.
            group_weights = [sum(importance[p] for p in group) for group in groups]
            total_weight = sum(group_weights) or 1.0
            group_trials = [
                max(1, int(round(self.step_trials * w / total_weight)))
                for w in group_weights
            ]
            delta = self.step_trials - sum(group_trials)
            if delta > 0 and group_weights:
                # Any rounding shortfall goes to the heaviest group.
                group_trials[group_weights.index(max(group_weights))] += delta

            remaining_trials = self.max_total_trials - total_trials_used
            trials_this_round = 0

            for group, budget in zip(groups, group_trials):
                allowed = min(budget, remaining_trials)
                if allowed <= 0:
                    break

                subspace = {k: self.search_space[k] for k in group}
                fixed_config = {
                    k: current_best_config[k]
                    for k in self.search_space
                    if k not in group
                }
                optimizer = self.optimizer_builder(subspace, self.history, fixed_config, allowed)
                try:
                    optimizer_results = optimizer.optimize()
                except Exception as e:
                    # Best-effort: a failing group is reported and skipped.
                    print(f"[GIF-NoFallback] Optimizer failed on group {group}: {e}")
                    continue

                new_trials = []
                for config, result, elapsed_time in optimizer_results:
                    # Optimizer configs only cover the group's params; fill the rest
                    # from the current best so every stored config is complete.
                    full_config = dict(current_best_config)
                    full_config.update(config)
                    new_trials.append({
                        "config": full_config,
                        "score": result,
                        "elapsed_time": elapsed_time,
                        "round": round_idx,
                        "group": group,
                    })

                if not new_trials:
                    # Nothing to record; max() below would fail on an empty list.
                    continue

                self.history.extend(new_trials)
                n_added = len(new_trials)
                total_trials_used += n_added
                remaining_trials -= n_added
                trials_this_round += n_added

                # Update the incumbent and log BEFORE checking the budget, so the
                # last group's results are never dropped.
                group_best = max(new_trials, key=lambda t: t["score"])
                if group_best["score"] > current_best_score:
                    current_best_score = group_best["score"]
                    current_best_config = dict(group_best["config"])

                self.logs.append({
                    "round": round_idx,
                    "group": group,
                    "trials": n_added,
                    "best_score": group_best["score"],
                    "importance_snapshot": dict(importance),
                    "no_fallback": True,
                })

                if remaining_trials <= 0:
                    break

            if trials_this_round == 0:
                # No progress is possible (all groups failed/empty, or no groups):
                # stop instead of looping forever.
                print(f"[GIF-NoFallback] Round {round_idx} produced no trials; stopping early.")
                break

            round_idx += 1

        self.current_best_config = current_best_config
        self.current_best_score = current_best_score
        return self.history
